########################################################
## Replication of Main Document
########################################################
## Clear every object from the current workspace, then record when the run
## began so the total elapsed time can be printed at the end of the script.
## NOTE(review): rm(list = ls()) removes objects only; attached packages and
## options() are untouched, so a fresh R session is still the cleanest start.
rm(list = ls())
start_time <- Sys.time()
########################################################
## Package loading
########################################################
library(BridgeChange)
library(devtools)  
library(coda)
library(plm)
library(pforeach)
library(dplyr)
library(colorspace)
library(monomvn)
library(glmnet)
library(MASS)
library(MCMCpack)
library(pcse)
library(pastecs)
library(xtable)
library(bayesplot)
library(dotwhisker)
library(broom)
library(stableGR)
library(tidyverse)
library(lmtest)  #
source("helper.R")


########################################################
## Figure 1 (a) ## graph of changepoint
########################################################

## Figure 1 (a): stylised coefficient paths for three parameter groups.
## Each path is a noisy logistic curve built from two halves of N/2 points:
## the first half runs from -5 to mu31[g], the second from mu31[g] to mu32[g]
## (both on the scale that is passed through sigmoid()).
set.seed(1999)
N <- 300      # number of time points per path
G <- 3        # number of parameter groups (one path each)
noise <- 0.1  # standard deviation of the Gaussian noise around each curve

## End points of the first and second half of each path, one entry per group.
mu31 <- c(5, -5, 5)
mu32 <- c(-5, -5, 5)

## Logistic curve 1 / (1 + exp(-b * x)); b controls how sharp the transition is.
sigmoid <- function(x, b = 4) {
    1 / (1 + exp(-b * x))
}

## Draw the paths one group at a time. The loop (rather than a vectorised
## call) keeps the order of rnorm() draws, and therefore the figure, fixed
## for the seed above.
mat3 <- matrix(NA, G, N)
for (g in seq_len(G)) {
    mat3[g, ] <- c(rnorm(N/2, sigmoid(seq(-5, mu31[g], length.out = N/2)), noise),
                   rnorm(N/2, sigmoid(seq(mu31[g], mu32[g], length.out = N/2)), noise))
}

col.scheme <- diverge_hcl(G, h = c(225, 330), l = c(20, 60))
x <- 1:N

pdf(file = "changepoint_type.pdf", width = 10, height = 8, family = "sans")
par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
## The first call sets up the plotting region; every group (including the
## first) is then redrawn in its own colour.
plot(x, mat3[1, ], type = "l", lwd = 0.5, ylim = range(mat3), axes = FALSE,
     ylab = "Coefficients", xlab = "Time")
axis(1, labels = FALSE)
axis(2, labels = FALSE)
for (g in seq_len(G)) {
    lines(x, mat3[g, ], type = "l", lwd = 0.5, col = col.scheme[g])
}
text(x = 150, y = min(mat3), "Parameter group B", pos = 3)
text(x = 275, y = max(mat3), "Parameter group C", pos = 1)
## Wide translucent bands mark the two transition periods.
abline(v = c(75, 225), lwd = 25, col = NetworkChange:::addTrans("grey60", 100))
## Position keyword spelled out in full ("topleft") instead of relying on
## partial matching of the misspelled "toplef".
legend("topleft", c("Parameter group A", "Parameter group B", "Parameter group C"),
       bty = "n", lty = 1, col = col.scheme)
dev.off()



########################################################
## Simulation 
########################################################
## Seed and sampler settings shared by every simulation run. The sampler
## values are forwarded unchanged to BridgeMixedPanel() inside the loop.
set.seed(39212)
n_chain <- 3     # number of chains run in parallel per design
mcmc    <- 1000  # passed as `mcmc`
burn    <- 1000  # passed as `burn`
verbose <- 500   # passed as `verbose`
thin    <- 1     # passed as `thin`

## Design factors of the simulation study.
t_time   <- c(30, 60)   # T: panel length
n_unit   <- c(10, 20)   # N: number of units
k_cov    <- c(20, 30)   # K: number of covariates
n_break  <- c(0, 1, 2)  # B: true number of breaks
s_sparse <- c(1)        # S: pattern code handed to gen_beta()

## Full factorial design: 2 x 2 x 2 x 3 x 1 = 24 rows. The first factor varies
## fastest, so rows are ordered by B (then K, T, N).
sim_set <- setNames(
    expand.grid(n_unit, t_time, k_cov, n_break, s_sparse),
    c("N", "T", "K", "B", "S")
)
n.sim <- nrow(sim_set)

## Result containers, one row per design.
rmse.report  <- matrix(NA, n.sim, 4)            ## c(rmse.ols, rmse0, rmse1, rmse2)
waic.report  <- matrix(NA, n.sim, 3)            ## WAIC for the 0-, 1-, 2-break fits
gr.report    <- waic.report                     ## mpsrf for the 0-, 1-, 2-break fits
state.report <- matrix(NA, n.sim, max(t_time))  ## estimated state path, padded with NA

fixed <- TRUE

########################################################
## Simulation Starts!
########################################################
## Run label: today's date followed by the first K and N design values.
sim.name <- paste0(
    lubridate::today(),
    "_common_twoway_smallpanel_k", k_cov[1],
    "_n", n_unit[1]
)
## Directory that will hold the saved simulation reports.
dir.create(paste0("sim_fixed", sim.name))

## Main simulation loop: one iteration per row of sim_set.
## Each iteration
##   1. simulates a panel with B common breaks (gen_cov / gen_beta /
##      dgp.break.common from helper.R),
##   2. fits a plm within estimator and BridgeMixedPanel with 0, 1 and 2
##      breaks on n_chain parallel chains,
##   3. records RMSE, the multivariate PSRF, WAIC and the estimated states,
##   4. writes all report matrices to disk so partial results survive a crash.
## Directory (created before the loop) that receives the report files.
out.dir <- paste0("sim_fixed", sim.name)

for (cl in seq_len(nrow(sim_set))) {
    effect <- "twoways"
    model <- "within"

    N <- sim_set[cl, 1]
    T <- sim_set[cl, 2]
    K <- sim_set[cl, 3]
    B <- sim_set[cl, 4]
    S <- sim_set[cl, 5]
    Q <- 2 ## for random effects

    cat("\n==================================\n")
    cat("simulation = ", cl, "\n")
    ## Label order matches the column order of sim_set.
    cat("setup (N, T, K, B, S) = ", unlist(sim_set[cl,]), "\n")
    cat("==================================\n")

    ## generate data ==================
    ## One break sits at mid-sample; two breaks sit at T/3 and T - T/3.
    break.point <- round(T / 2)
    if (B == 2) {
        break.point <- c(round(T/3), T - round(T/3))
    }
    break.sigma <- 1

    ## generate covariates
    x <- as.matrix(scale(abs(gen_cov(N, T, K))))

    ## ground truth
    mu1 <- 2
    K.ground <- if (fixed) K else K + 1 ## if random, intercept is added.
    ## gen_beta() is called before the runif() draws in every branch so the
    ## random-number stream is consumed in a fixed order.
    if (B == 0) {
        true.beta  <- gen_beta(K.ground, mu = mu1, tau = 1, pattern = S, n_break = B)
        true.sigma <- sqrt(2)
        true.D     <- diag(runif(Q), Q)
    } else if (B == 1) {
        out <- gen_beta(K.ground, mu = mu1, tau = 1, pattern = S, n_break = B)
        true.beta  <- list(out[[1]], out[[2]])
        true.sigma <- list(sqrt(2), sqrt(3))
        true.D     <- list(diag(runif(Q), Q), diag(runif(Q), Q))
    } else {
        out <- gen_beta(K.ground, mu = mu1, tau = 1, pattern = S, n_break = B)
        true.beta  <- list(out[[1]], out[[2]], out[[3]])
        true.sigma <- list(sqrt(2), sqrt(3), sqrt(2))
        true.D     <- list(diag(runif(Q), Q), diag(runif(Q), Q), diag(runif(Q), Q))
    }

########################################################
    ## DGP
########################################################
    out <- dgp.break.common(x, N = N, T = T, Q = Q, B = B,
                            true.beta, true.sigma, true.D,
                            break.point = break.point, break.sigma = break.sigma,
                            fixed = TRUE, print = FALSE)

    ## raw data
    y <- out$y
    X <- out$X
    subject.id <- out$subject.id
    time.id <- out$time.id
    data <- as.data.frame(cbind(y, X, subject.id, time.id))
    true.beta.time <- out$true.beta.time
    ## Visual check of the simulated outcome over time, one panel per unit.
    coplot(y ~ time.id|subject.id, data = data)

########################################################
    ## Estimation
########################################################

    W <- matrix(0, length(y), 1)

    measurevar <- "y"
    groupvars  <- colnames(X)

    ## Build "y ~ x1 + x2 + ..." from the covariate names.
    formula <- as.formula(paste(measurevar, paste(groupvars, collapse = " + "), sep = " ~ "))
    index <- c("subject.id", "time.id")
    pdata <- plm::pdata.frame(data, index)

    ## NOTE(review): plm.X / plm.y (and the unscaled.* copies below) are
    ## computed with effect = "twoways" but are not used later in this loop;
    ## the samplers receive the raw y and X.
    plm.X <- model.matrix(formula, pdata, rhs = 1, model = model, effect = effect)
    plm.y <- plm:::pmodel.response(formula, pdata, model = model, effect = effect)
    ## The benchmark plm fit uses individual (not two-way) within effects.
    effect <- "individual"
    model <- "within"
    ols <- plm(formula, pdata, model = model, effect = effect)
    ols.beta <- ols$coefficients

    unscaled.Y <- plm.y
    unscaled.X <- plm.X
    subject.id <- as.numeric(as.factor(data[, index[1]]))
    time.id    <- as.numeric(as.factor(data[, index[2]]))
    c0 <- .1; d0 <- .1; r0 <- 1; R0 <- 1

    ## Each chain fits the 0-, 1- and 2-break models in that order.
    sim_out <- pforeach(i = 1:n_chain, .cores = n_chain, .seed = 3821, .c = "list")({
        test0 <- BridgeMixedPanel(subject.id = subject.id, time.id = time.id,
                                  standardize = FALSE,
                                  mcmc = mcmc, burn = burn, thin = thin, verbose = verbose,
                                  y = y, X = X, W = W, n.break = 0,
                                  r0 = r0, R0 = R0, c0 = c0, d0 = d0,
                                  alpha.MH = TRUE,
                                  unscaled.Y = y, unscaled.X = X, Waic = TRUE)

        test1 <- BridgeMixedPanel(subject.id = subject.id, time.id = time.id,
                                  standardize = FALSE,
                                  mcmc = mcmc, burn = burn, thin = thin, verbose = verbose,
                                  y = y, X = X, W = W, n.break = 1,
                                  r0 = r0, R0 = R0, c0 = c0, d0 = d0,
                                  alpha.MH = TRUE,
                                  unscaled.Y = y, unscaled.X = X, Waic = TRUE)

        test2 <- BridgeMixedPanel(subject.id = subject.id, time.id = time.id,
                                  standardize = FALSE,
                                  mcmc = mcmc, burn = burn, thin = thin, verbose = verbose,
                                  y = y, X = X, W = W, n.break = 2,
                                  r0 = r0, R0 = R0, c0 = c0, d0 = d0,
                                  alpha.MH = TRUE,
                                  unscaled.Y = y, unscaled.X = X, Waic = TRUE)
        list("test0" = test0, "test1" = test1, "test2" = test2)
    })

    ## Per-model mcmc.list of the "beta" columns, one element per chain.
    test0l <- mcmc.list(lapply(sim_out, function(x) x[[1]][, grep("beta", colnames(x[[1]]))]))
    test1l <- mcmc.list(lapply(sim_out, function(x) x[[2]][, grep("beta", colnames(x[[2]]))]))
    test2l <- mcmc.list(lapply(sim_out, function(x) x[[3]][, grep("beta", colnames(x[[3]]))]))

    ## Pool the draws of all chains row-wise.
    test0 <- do.call("rbind", test0l)
    test1 <- do.call("rbind", test1l)
    test2 <- do.call("rbind", test2l)

    est.beta0 <- test0[, grep("beta", colnames(test0))]
    est.beta1 <- test1[, grep("beta", colnames(test1))]
    est.beta2 <- test2[, grep("beta", colnames(test2))]
    n.coef <- dim(est.beta0)[2]

    ## m = 0: the same posterior mean in every period.
    mat.beta0 <- matrix(rep(apply(est.beta0, 2, mean), T), n.coef, T)

    ## m = 1: start from regime 1 everywhere, then overwrite the columns
    ## assigned to regime 2. States are the rounded column means of the
    ## "s.store" attribute of the third chain.
    mu.beta1 <- apply(est.beta1, 2, mean)
    mu.beta11 <- mu.beta1[grep("regime1", names(mu.beta1))]
    mu.beta12 <- mu.beta1[grep("regime2", names(mu.beta1))]
    state1 <- round(apply(attr(sim_out[[3]][[2]], "s.store"), 2, mean))
    mat.beta1 <- matrix(rep(mu.beta11, T), n.coef, T)
    mat.beta1[, state1 == 2] <- mu.beta12

    ## m = 2: same construction with three regimes. Regime 1 must come from
    ## the two-break fit (mu.beta21), not from the one-break fit.
    mu.beta2 <- apply(est.beta2, 2, mean)
    mu.beta21 <- mu.beta2[grep("regime1", names(mu.beta2))]
    mu.beta22 <- mu.beta2[grep("regime2", names(mu.beta2))]
    mu.beta23 <- mu.beta2[grep("regime3", names(mu.beta2))]
    state2 <- round(apply(attr(sim_out[[3]][[3]], "s.store"), 2, mean))
    mat.beta2 <- matrix(rep(mu.beta21, T), n.coef, T)
    mat.beta2[, state2 == 2] <- mu.beta22
    mat.beta2[, state2 == 3] <- mu.beta23

    ## plm within estimates (effect = "individual", set above), constant over time.
    mat.ols <- matrix(rep(ols.beta, T), n.coef, T)

    rmse0 <- rmse.fn(true.beta.time - mat.beta0)
    rmse1 <- rmse.fn(true.beta.time - mat.beta1)
    rmse2 <- rmse.fn(true.beta.time - mat.beta2)
    rmse.ols <- rmse.fn(true.beta.time - mat.ols)

    ## Multivariate potential scale reduction factor per model; values close
    ## to 1 indicate approximate convergence.
    gr <- c(stableGR::stable.GR(test0l)$mpsrf,
            stableGR::stable.GR(test1l)$mpsrf,
            stableGR::stable.GR(test2l)$mpsrf)

    rmse.report[cl, ] <- c(rmse.ols, rmse0, rmse1, rmse2)
    gr.report[cl, ] <- gr
    ## Store the state path of the model matching the true number of breaks;
    ## nothing is stored when B == 0.
    if (B == 1) {
        state.report[cl, 1:length(state1)] <- state1
    } else if (B != 0) {
        state.report[cl, 1:length(state2)] <- state2
    }

    ## First element of the "Waic.out" attribute, averaged over chains.
    waic.report[cl, ] <- c(mean(unlist(lapply(sim_out, function(x) {attr(x[[1]], "Waic.out")[1]}))),
                           mean(unlist(lapply(sim_out, function(x) {attr(x[[2]], "Waic.out")[1]}))),
                           mean(unlist(lapply(sim_out, function(x) {attr(x[[3]], "Waic.out")[1]}))))

########################################################
    ## Figure 4: overdetection plot
########################################################
    ## Design 16 gives Figure 4(a) and design 24 gives Figure 4(b); both show
    ## the one- and two-break fits of the second chain.
    if (cl %in% c(16, 24)) {
        fig4.file <- if (cl == 16) "Figure4a.pdf" else "Figure4b.pdf"
        pdf(file = fig4.file, width = 12, height = 8, family = "sans")
        par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
        par(mfrow = c(2, 2))
        dotplotRegime(sim_out[[2]][[2]], hybrid = FALSE, location.bar = 30, x.location = "none",
                      text.cex = 0.8, main = "")
        dotplotRegime(sim_out[[2]][[3]], hybrid = FALSE, location.bar = 30, x.location = "none",
                      text.cex = 0.8, main = "")

        plotState(sim_out[[2]][[2]], legend.control = c(12, 0.85), main = "Break 1")
        plotState(sim_out[[2]][[3]], legend.control = c(12, 0.85), main = "Break 2")
        dev.off()
    }

########################################################
    ## Store
########################################################
    ## Rewritten after every design so completed rows are never lost.
    saveRDS(rmse.report, file = file.path(out.dir, "rmse.rds"))
    saveRDS(gr.report, file = file.path(out.dir, "gr.rds"))
    saveRDS(waic.report, file = file.path(out.dir, "waic.rds"))
    saveRDS(state.report, file = file.path(out.dir, "state.rds"))

    cat("\n==================================\n")
    cat("simulation = ", cl, " is done ! \n")
    cat("==================================\n")

}

## Reload the report matrices written by the simulation loop.
report.dir <- paste0("sim_fixed", sim.name)
rmse  <- readRDS(file = file.path(report.dir, "rmse.rds"))
gr    <- readRDS(file = file.path(report.dir, "gr.rds"))
waic  <- readRDS(file = file.path(report.dir, "waic.rds"))
state <- readRDS(file = file.path(report.dir, "state.rds"))

## Pull the ground truth of every design from the design grid.
true.m <- sim_set$B
true.t <- sim_set$T
true.n <- sim_set$N
true.k <- sim_set$K

## True state path per design, built from the same break locations the
## simulation loop uses: round(T / 2) for one break, round(T / 3) and
## T - round(T / 3) for two breaks. Using round() here as well keeps the path
## exactly T long when T is odd. Designs without a break keep a NULL entry.
true.state <- vector("list", n.sim)
for (i in seq_len(n.sim)) {
    if (true.m[i] == 1) {
        bp <- round(true.t[i] / 2)
        true.state[[i]] <- c(rep(1, bp), rep(2, true.t[i] - bp))
    } else if (true.m[i] == 2) {
        bp <- c(round(true.t[i] / 3), true.t[i] - round(true.t[i] / 3))
        true.state[[i]] <- c(rep(1, bp[1]),
                             rep(2, bp[2] - bp[1]),
                             rep(3, true.t[i] - bp[2]))
    } else {
        cat("m = ", true.m[i], "\n")
    }
}

##############################
## Figure 2
##############################
## Figure 2 (a): RMSE by estimator, one panel per true number of breaks.
## Every design is a grey line across the four estimators; large translucent
## dots highlight the HMBB model with the true number of breaks (and, in the
## no-break panel, the FE benchmark as well).
pdf(file = "Figure2a.pdf", width = 10, height = 3.5, family = "sans")
par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
par(mfrow = c(1, 3))
mark.col <- NetworkChange:::addTrans("brown", 100)
model.labels <- c("FE", "HMBB break 0", "HMBB break 1", "HMBB break 2")
## `counter` walks through the designs in row order of sim_set, which is
## sorted by the number of breaks, so it stays aligned with the rows of df.
counter <- 1
for (panel.m in 0:2) {
    df <- rmse[true.m == panel.m, ]
    plot(df[1, ], ylim = range(df), pch = 19, axes = FALSE, type = "n",
         ylab = "RMSE", xlab = "Model")
    for (j in seq_len(nrow(df))) {
        m.now <- true.m[counter]
        design.label <- paste0("T=", true.t[counter], ", N=", true.n[counter],
                               ", K=", true.k[counter])
        lines(df[j, ], col = "grey70")
        points(df[j, ], pch = 19, cex = 1)
        if (m.now == 0) {
            ## plm
            points(m.now + 1, df[j, m.now + 1], pch = 19, col = mark.col, cex = 5)
        }
        ## break
        points(m.now + 2, df[j, m.now + 2], pch = 19, col = mark.col, cex = 5)
        text(m.now + 2, df[j, m.now + 2], design.label, pos = 3 - (m.now - 1))
        counter <- counter + 1
        cat("counter  = ", counter, "\n")
    }
    mtext(paste0("Break number = ", true.m[counter - 1]), 3)
    axis(1, at = seq_len(ncol(df)), labels = model.labels)
    axis(2)
    box()
}
dev.off()

## Figure 2 (b): WAIC by fitted number of breaks, one panel per true number
## of breaks. Every design is a grey line; the large translucent dot marks
## the model whose break count equals the truth.
pdf(file = "Figure2b.pdf", width = 10, height = 3.5, family = "sans")
par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
par(mfrow = c(1, 3))
mark.col <- NetworkChange:::addTrans("brown", 100)
## `counter` follows the row order of sim_set (sorted by number of breaks).
counter <- 1
for (panel.m in 0:2) {
    df <- waic[true.m == panel.m, ]
    plot(df[1, ], ylim = range(df), pch = 19, axes = FALSE, type = "n",
         ylab = "WAIC", xlab = "Model")
    for (j in seq_len(nrow(df))) {
        m.now <- true.m[counter]
        true.col <- m.now + 1
        design.label <- paste0("T=", true.t[counter], ", N=", true.n[counter],
                               ", K=", true.k[counter])
        lines(df[j, ], col = "grey70")
        points(df[j, ], pch = 19, cex = 1)
        points(true.col, df[j, true.col], pch = 19, col = mark.col, cex = 5)
        text(true.col, df[j, true.col], design.label, pos = 3 - (m.now - 1))
        counter <- counter + 1
        cat("counter  = ", counter, "\n")
    }
    mtext(paste0("Break number = ", true.m[counter - 1]), 3)
    axis(1, at = 1:3, labels = c("break 0", "break 1", "break 2"))
    axis(2)
    box()
}
dev.off()
    
## Figure 3 (a): hidden state recovery on a 2 x 2 grid of panels.
## Rows are the first and second distinct panel lengths in true.t; columns
## are the one-break and two-break designs. Estimated paths (jittered, light
## grey) are drawn under the true path (dark grey) for each matching design.
pdf(file = "Figure3a.pdf", width = 10, height = 5, family = "sans")
par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
par(mfrow = c(2, 2))
for (t.idx in 1:2) {
    t.val <- unique(true.t)[t.idx]
    for (m.val in 1:2) {
        in.panel <- true.t == t.val & true.m == m.val
        ## Empty frame: x spans 1..T, y leaves half a unit around the states.
        plot(1:max(true.t[in.panel]), ylim = c(0.5, m.val + 1.5), type = "n",
             axes = FALSE, ylab = "Hidden state", xlab = "Time")
        axis(1)
        axis(2, at = 1:3)
        box()
        mtext(paste0("Hidden State Recovery (m = ", m.val, ", T = ", t.val, ")"), 3)
        for (i in which(in.panel)) {
            lines(jitter(state[i, seq_along(true.state[[i]])]), col = "grey60", cex = 0.2)
            lines(true.state[[i]], col = "grey20", lwd = 2)
        }
        legend("topleft", legend = c("True", "Estimated"), col = c("grey20", "grey60"),
               lty = 1, lwd = c(2, 1), bty = "n")
    }
}
dev.off()


## Figure 3 (b): convergence diagnostic of the correctly specified model.
## For design i the plotted value is gr[i, true.m[i] + 1], i.e. the column of
## gr that belongs to the model with the true number of breaks.
pdf(file = "Figure3b.pdf", width = 10, height = 5, family = "sans")
par(mar = c(3, 3, 2, 1), mgp = c(2, .7, 0), tck = -.01)
rhat.true <- gr[cbind(seq_len(nrow(sim_set)), true.m + 1)]
plot(rhat.true, pch = 19, type = "p", axes = FALSE,
     ylab = "Rhat", xlab = "Simulation")
axis(1)
axis(2)
## Dotted reference lines at 1.0 and 1.1.
abline(h = c(1.0, 1.1), lty = 3, col = "grey60")
box()
mtext("Gelman-Rubin Convergence Diagnostics", 3)
dev.off()


## Report the total wall-clock time of the replication run.
end_time <- Sys.time()
elapsed <- end_time - start_time
print(elapsed)
